import tensorflow as tf
# MNIST data-loading helper (TensorFlow 1.x tutorial module)
from tensorflow.examples.tutorials.mnist import input_data
import random
import matplotlib.pyplot as plt
tf.set_random_seed(777)  # graph-level random seed, for reproducible initialisation

# The MNIST data is split into three parts:
# 55,000 data points of training data (mnist.train)
# 10,000 points of test data (mnist.test), and
# 5,000 points of validation data (mnist.validation).

# Handwritten digit recognition with softmax (multinomial logistic) regression.
#
# NOTE: keep the order in which graph ops are created below unchanged. With only
# a graph-level seed set, TensorFlow derives each random op's seed from the graph
# seed and (in TF 1.x) the graph's op count, so inserting or reordering ops before
# the random_normal calls can change the initial W and b.

# Each image is 28 x 28 pixels, fed as a flattened 784-element row.
# Load the dataset from the "MNIST_data" directory with one-hot encoded labels.
mnist = input_data.read_data_sets("MNIST_data", one_hot=True)
nb_classes = 10
nb_pixels = 784  # 28 * 28
learning_rate = 0.1

# Placeholders fed per batch: flattened images, and one one-hot row of
# `nb_classes` values per image.
X = tf.placeholder(tf.float32, shape=[None, nb_pixels])
Y = tf.placeholder(tf.float32, [None, nb_classes])

# Model parameters (weights and bias), initialised with tf.random_normal.
W = tf.Variable(tf.random_normal([nb_pixels, nb_classes]), name='weight')
b = tf.Variable(tf.random_normal([nb_classes]), name='bias')

# Unnormalised class scores; softmax is applied inside the loss op below.
logits = tf.matmul(X, W) + b

# Mean cross-entropy loss. The fused op applies softmax to `logits` and then
# computes cross-entropy in one step. This avoids taking tf.log of a separately
# computed softmax, which can evaluate log(0).
cost = tf.reduce_mean(
    tf.nn.softmax_cross_entropy_with_logits(logits=logits, labels=Y))

# Plain gradient-descent optimiser that minimises the loss.
train = tf.train.GradientDescentOptimizer(
    learning_rate=learning_rate).minimize(cost)

# Accuracy: fraction of rows whose arg-max logit matches the one-hot label index.
prediction = tf.argmax(logits, 1)
correct_prediction = tf.equal(prediction, tf.argmax(Y, 1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

# Session shared by the training and evaluation code in the rest of the script.
sess = tf.Session()
sess.run(tf.global_variables_initializer())  # initialise W and b
# Mini-batch training: make `training_epochs` passes over the training split,
# taking one gradient-descent step per batch of `batch_size` examples.
training_epochs = 15
batch_size = 100
for epoch in range(training_epochs):
    # Batches per epoch (55,000 // 100 = 550 for the default split).
    n_batches = mnist.train.num_examples // batch_size
    running_cost = 0
    for _step in range(n_batches):
        images, labels = mnist.train.next_batch(batch_size)
        fetched = sess.run([cost, train], feed_dict={X: images, Y: labels})
        # fetched[0] is this batch's loss; the train op's result is unused.
        running_cost += fetched[0] / n_batches
    # Mean batch loss for this epoch, printed to monitor convergence.
    print(epoch, running_cost)
# Accuracy on the first 5,000 images of the 10,000-image test set.
print("Accuracy: ", sess.run(accuracy, feed_dict={X: mnist.test.images[:5000], Y: mnist.test.labels[:5000]}))


def _show_random_test_sample():
    """Print the true label and the model's prediction for one random test image, then plot it.

    Reuses the existing `prediction` tensor rather than calling tf.argmax here.
    Creating new graph ops on every call would make the graph grow for as long as
    the interactive loop below keeps running. The true label is decoded from its
    one-hot row with the array's own argmax, so it needs no graph op at all.
    """
    idx = random.randint(0, mnist.test.num_examples - 1)
    image = mnist.test.images[idx:idx + 1]
    print("Label: ", mnist.test.labels[idx:idx + 1].argmax(axis=1))
    print("Prediction: ", sess.run(prediction, feed_dict={X: image}))
    plt.imshow(image.reshape(28, 28), cmap='Greys')
    plt.show()


# Show one random test sample. After that, show a new one each time the user
# presses Enter, until they type 'q' or stdin is closed.
_show_random_test_sample()
while True:
    try:
        command = input()
    except EOFError:
        break
    if command.strip() == 'q':
        break
    try:
        _show_random_test_sample()
    except Exception as err:  # best effort: report the failure, keep browsing
        print("Could not show sample:", err)
# Recorded console output from a run of the training loop above. Each line is
# "<epoch> <average cost>", where the cost is the mean of that epoch's batch losses.
# This string literal is a no-op statement kept only for reference.
'''
0 2.8263026701320304
1 1.06166895113208
2 0.8380613085898487
3 0.7332327307354319
4 0.6692798727750779
5 0.624611818573691
6 0.5911603328856556
7 0.5638689751245757
8 0.5417451611703092
9 0.522673569335179
10 0.506782321062955
11 0.49244763639840267
12 0.47995582400397785
13 0.4688936629078603
14 0.458703470866789
'''